{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Gwt7z7qdmTbW"
   },
   "outputs": [],
   "source": [
    "# Copyright 2019 NVIDIA Corporation. All Rights Reserved.\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License.\n",
    "# =============================================================================="
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "i4NKCp2VmTbn"
   },
   "source": [
    "<img src=\"http://developer.download.nvidia.com/compute/machine-learning/frameworks/nvidia_logo.png\" style=\"width: 90px; float: right;\">\n",
    "\n",
    "# Kaldi TRTIS Inference Online Demo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "fW0OKDzvmTbt"
   },
   "source": [
    "## Overview\n",
    "\n",
    "\n",
    "This repository provides a wrapper around the online GPU-accelerated ASR pipeline from the paper [GPU-Accelerated Viterbi Exact Lattice Decoder for Batched Online and Offline Speech Recognition](https://arxiv.org/abs/1910.10032). That work includes a high-performance implementation of a GPU HMM Decoder, a low-latency Neural Net driver, fast Feature Extraction for preprocessing, and new ASR pipelines tailored for GPUs. These different modules have been integrated into the Kaldi ASR framework.\n",
    "\n",
    "This repository contains a TensorRT Inference Server custom backend for the Kaldi ASR framework. This custom backend calls the high-performance online GPU pipeline from the Kaldi ASR framework. The TensorRT Inference Server integration makes Kaldi ASR inference easier to use by providing a gRPC streaming server, dynamic sequence batching, and multi-instance support. A client connects to the gRPC server, streams audio by sending chunks to the server, and receives the inferred text in response. More information about the TensorRT Inference Server can be found [here](https://docs.nvidia.com/deeplearning/sdk/tensorrt-inference-server-guide/docs/).\n",
    "\n",
    "\n",
    "\n",
    "### Learning objectives\n",
    "\n",
    "This notebook demonstrates how to run inference with the Kaldi TRTIS backend server using a Python gRPC client in an online setting: we stream live audio from a microphone to the inference server and receive the transcription results back.\n",
    "\n",
    "## Content\n",
    "1. [Prerequisites](#1)\n",
    "1. [Setup](#2)\n",
    "1. [Audio helper classes](#3)\n",
    "1. [Inference](#4)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "aDFrE4eqmTbv"
   },
   "source": [
    "<a id=\"1\"></a>\n",
    "## 1. Prerequisites\n",
    "\n",
    "\n",
    "### 1.1 Docker containers\n",
    "Follow the steps in the [README](README.md) to build the Kaldi server and client containers.\n",
    "\n",
    "### 1.2 Hardware\n",
    "This notebook can be executed on any CUDA-enabled NVIDIA GPU, although a [Tensor Core NVIDIA GPU](https://www.nvidia.com/en-us/data-center/tensorcore/) (Volta, Turing or newer architecture) is recommended for efficient mixed precision inference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "k7RLEcKhmTb0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Thu Mar  5 00:28:21 2020       \r\n",
      "+-----------------------------------------------------------------------------+\r\n",
      "| NVIDIA-SMI 440.48.02    Driver Version: 440.48.02    CUDA Version: 10.2     |\r\n",
      "|-------------------------------+----------------------+----------------------+\r\n",
      "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\r\n",
      "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\r\n",
      "|===============================+======================+======================|\r\n",
      "|   0  Quadro GV100        Off  | 00000000:05:00.0 Off |                  Off |\r\n",
      "| 32%   42C    P2    28W / 250W |  17706MiB / 32506MiB |      3%      Default |\r\n",
      "+-------------------------------+----------------------+----------------------+\r\n",
      "                                                                               \r\n",
      "+-----------------------------------------------------------------------------+\r\n",
      "| Processes:                                                       GPU Memory |\r\n",
      "|  GPU       PID   Type   Process name                             Usage      |\r\n",
      "|=============================================================================|\r\n",
      "+-----------------------------------------------------------------------------+\r\n"
     ]
    }
   ],
   "source": [
    "!nvidia-smi"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "EQAIszkxmTcT"
   },
   "source": [
    "This notebook also requires access to a microphone. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"2\"></a>\n",
    "## 2. Setup\n",
    "### Import libraries and parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os\n",
    "import sys\n",
    "import threading\n",
    "from builtins import range\n",
    "from functools import partial\n",
    "\n",
    "import numpy as np\n",
    "import soundfile\n",
    "import librosa\n",
    "import pyaudio as pa\n",
    "\n",
    "import grpc\n",
    "from tensorrtserver.api import api_pb2\n",
    "from tensorrtserver.api import grpc_service_pb2\n",
    "from tensorrtserver.api import grpc_service_pb2_grpc\n",
    "import tensorrtserver.api.model_config_pb2 as model_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument('-v', '--verbose', action=\"store_true\", required=False, default=False,\n",
    "                    help='Enable verbose output')\n",
    "parser.add_argument('-a', '--async', dest=\"async_set\", action=\"store_true\", required=False,\n",
    "                    default=False, help='Use asynchronous inference API')\n",
    "parser.add_argument('--streaming', action=\"store_true\", required=False, default=False,\n",
    "                    help='Use streaming inference API')\n",
    "parser.add_argument('-m', '--model-name', type=str, required=False, default='kaldi_online',\n",
    "                    help='Name of model')\n",
    "parser.add_argument('-x', '--model-version', type=int, required=False, default=1,\n",
    "                    help='Version of model. Default is 1.')\n",
    "parser.add_argument('-b', '--batch-size', type=int, required=False, default=1,\n",
    "                    help='Batch size. Default is 1.')\n",
    "parser.add_argument('-u', '--url', type=str, required=False, default='localhost:8001',\n",
    "                    help='Inference server URL. Default is localhost:8001.')\n",
    "parser.add_argument('--chunk_duration', type=float, required=False,\n",
    "                    default=0.51,\n",
    "                    help=\"duration of the audio chunk for streaming \"\n",
    "                            \"recognition, in seconds\")\n",
    "parser.add_argument('--input_device_id', type=int, required=False,\n",
    "                    default=-1, help='Input device id to use to capture audio')\n",
    "parser.add_argument('--sample_rate', type=int, required=False,\n",
    "                    default=16000, help='Sample rate.')\n",
    "# Parse an empty argument list: inside Jupyter, sys.argv holds the kernel's own\n",
    "# arguments (e.g. '-f <connection file>'), so edit the defaults above instead.\n",
    "FLAGS = parser.parse_args(args=[])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Checking server status\n",
    "\n",
    "We first query the status of the server. The target model is 'kaldi_online'. A successful deployment of the Kaldi TRTIS server should produce output similar to the following.\n",
    "\n",
    "```\n",
    "request_status {\n",
    "  code: SUCCESS\n",
    "  server_id: \"inference:0\"\n",
    "  request_id: 17514\n",
    "}\n",
    "server_status {\n",
    "  id: \"inference:0\"\n",
    "  version: \"1.9.0\"\n",
    "  uptime_ns: 14179155408971\n",
    "  model_status {\n",
    "    key: \"kaldi_online\"\n",
    "...\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "request_status {\n",
      "  code: SUCCESS\n",
      "  server_id: \"inference:0\"\n",
      "  request_id: 6234\n",
      "}\n",
      "server_status {\n",
      "  id: \"inference:0\"\n",
      "  version: \"1.9.0\"\n",
      "  uptime_ns: 4061941924008\n",
      "  model_status {\n",
      "    key: \"kaldi_online\"\n",
      "    value {\n",
      "      config {\n",
      "        name: \"kaldi_online\"\n",
      "        platform: \"custom\"\n",
      "        version_policy {\n",
      "          latest {\n",
      "            num_versions: 1\n",
      "          }\n",
      "        }\n",
      "        max_batch_size: 2200\n",
      "        input {\n",
      "          name: \"WAV_DATA\"\n",
      "          data_type: TYPE_FP32\n",
      "          dims: 8160\n",
      "        }\n",
      "        input {\n",
      "          name: \"WAV_DATA_DIM\"\n",
      "          data_type: TYPE_INT32\n",
      "          dims: 1\n",
      "        }\n",
      "        output {\n",
      "          name: \"TEXT\"\n",
      "          data_type: TYPE_STRING\n",
      "          dims: 1\n",
      "        }\n",
      "        instance_group {\n",
      "          name: \"kaldi_online_0\"\n",
      "          count: 2\n",
      "          gpus: 0\n",
      "          kind: KIND_GPU\n",
      "        }\n",
      "        default_model_filename: \"libkaldi-trtisbackend.so\"\n",
      "        sequence_batching {\n",
      "          max_sequence_idle_microseconds: 5000000\n",
      "          control_input {\n",
      "            name: \"START\"\n",
      "            control {\n",
      "              int32_false_true: 0\n",
      "              int32_false_true: 1\n",
      "            }\n",
      "          }\n",
      "          control_input {\n",
      "            name: \"READY\"\n",
      "            control {\n",
      "              kind: CONTROL_SEQUENCE_READY\n",
      "              int32_false_true: 0\n",
      "              int32_false_true: 1\n",
      "            }\n",
      "          }\n",
      "          control_input {\n",
      "            name: \"END\"\n",
      "            control {\n",
      "              kind: CONTROL_SEQUENCE_END\n",
      "              int32_false_true: 0\n",
      "              int32_false_true: 1\n",
      "            }\n",
      "          }\n",
      "          control_input {\n",
      "            name: \"CORRID\"\n",
      "            control {\n",
      "              kind: CONTROL_SEQUENCE_CORRID\n",
      "              data_type: TYPE_UINT64\n",
      "            }\n",
      "          }\n",
      "          oldest {\n",
      "            max_candidate_sequences: 2200\n",
      "            preferred_batch_size: 256\n",
      "            preferred_batch_size: 512\n",
      "            max_queue_delay_microseconds: 1000\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"acoustic_scale\"\n",
      "          value {\n",
      "            string_value: \"1.0\"\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"beam\"\n",
      "          value {\n",
      "            string_value: \"10\"\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"frame_subsampling_factor\"\n",
      "          value {\n",
      "            string_value: \"3\"\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"fst_rxfilename\"\n",
      "          value {\n",
      "            string_value: \"/data/models/LibriSpeech/HCLG.fst\"\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"ivector_filename\"\n",
      "          value {\n",
      "            string_value: \"/data/models/LibriSpeech/conf/ivector_extractor.conf\"\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"lattice_beam\"\n",
      "          value {\n",
      "            string_value: \"7\"\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"max_active\"\n",
      "          value {\n",
      "            string_value: \"10000\"\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"max_execution_batch_size\"\n",
      "          value {\n",
      "            string_value: \"512\"\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"mfcc_filename\"\n",
      "          value {\n",
      "            string_value: \"/data/models/LibriSpeech/conf/mfcc.conf\"\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"nnet3_rxfilename\"\n",
      "          value {\n",
      "            string_value: \"/data/models/LibriSpeech/final.mdl\"\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"num_worker_threads\"\n",
      "          value {\n",
      "            string_value: \"40\"\n",
      "          }\n",
      "        }\n",
      "        parameters {\n",
      "          key: \"word_syms_rxfilename\"\n",
      "          value {\n",
      "            string_value: \"/data/models/LibriSpeech/words.txt\"\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "      version_status {\n",
      "        key: 1\n",
      "        value {\n",
      "          ready_state: MODEL_READY\n",
      "          infer_stats {\n",
      "            key: 1\n",
      "            value {\n",
      "              success {\n",
      "                count: 6913\n",
      "                total_time_ns: 233146745257\n",
      "              }\n",
      "              compute {\n",
      "                count: 6913\n",
      "                total_time_ns: 225589026013\n",
      "              }\n",
      "              queue {\n",
      "                count: 6913\n",
      "                total_time_ns: 7398387984\n",
      "              }\n",
      "            }\n",
      "          }\n",
      "          model_execution_count: 6913\n",
      "          model_inference_count: 6913\n",
      "          ready_state_reason {\n",
      "          }\n",
      "          last_inference_timestamp_milliseconds: 13619175935932035456\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "  ready_state: SERVER_READY\n",
      "}\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Create gRPC stub for communicating with the server\n",
    "channel = grpc.insecure_channel(FLAGS.url)\n",
    "grpc_stub = grpc_service_pb2_grpc.GRPCServiceStub(channel)\n",
    "\n",
    "# Prepare request for Status gRPC\n",
    "request = grpc_service_pb2.StatusRequest(model_name=FLAGS.model_name)\n",
    "# Call and receive response from Status gRPC\n",
    "response = grpc_stub.Status(request)\n",
    "\n",
    "print(response)"
   ]
  },
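  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The chunk size used later in this notebook (8160 samples) is the `dims` value of the `WAV_DATA` input in the configuration printed above. As an optional sketch, it can be read from the `Status` response instead of being hard-coded. `get_input_dims` is a hypothetical helper, not part of the TRTIS client API; it assumes the response layout visible in the printed output (`server_status.model_status[model_name].config.input`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_input_dims(status_response, model_name, input_name):\n",
    "    # Hypothetical helper: walk from the Status response to the model config,\n",
    "    # following the structure of the printed output above\n",
    "    config = status_response.server_status.model_status[model_name].config\n",
    "    for model_input in config.input:\n",
    "        if model_input.name == input_name:\n",
    "            return list(model_input.dims)\n",
    "    raise KeyError('Input %s not found in model %s' % (input_name, model_name))\n",
    "\n",
    "\n",
    "# Query the live response only if the status cell above has been run\n",
    "if 'response' in globals():\n",
    "    print(get_input_dims(response, FLAGS.model_name, 'WAV_DATA'))"
   ]
  },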
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Testing microphone\n",
    "\n",
    "Next, we list the audio input devices available on the system. Select the device that corresponds to your microphone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Input Devices:\n",
      "0: HDA Intel PCH: ALC1150 Analog (hw:0,0)\n",
      "1: HDA Intel PCH: ALC1150 Digital (hw:0,1)\n",
      "2: HDA Intel PCH: ALC1150 Alt Analog (hw:0,2)\n",
      "3: HD Pro Webcam C920: USB Audio (hw:1,0)\n",
      "4: HDA NVidia: HDMI 0 (hw:2,3)\n",
      "5: HDA NVidia: HDMI 2 (hw:2,8)\n",
      "6: HDA NVidia: HDMI 3 (hw:2,9)\n",
      "7: sysdefault\n",
      "8: front\n",
      "9: surround21\n",
      "10: surround40\n",
      "11: surround41\n",
      "12: surround50\n",
      "13: surround51\n",
      "14: surround71\n",
      "15: iec958\n",
      "16: spdif\n",
      "17: default\n",
      "18: dmix\n",
      "Enter device id to use: 3\n"
     ]
    }
   ],
   "source": [
    "import pyaudio\n",
    "import wave\n",
    "\n",
    "p = pyaudio.PyAudio()  # Create an interface to PortAudio\n",
    "\n",
    "device_info = p.get_host_api_info_by_index(0)\n",
    "num_devices = device_info.get('deviceCount')\n",
    "\n",
    "# Keep only devices that can capture audio (at least one input channel)\n",
    "devices = {}\n",
    "for i in range(0, num_devices):\n",
    "    info = p.get_device_info_by_host_api_device_index(0, i)\n",
    "    if info.get('maxInputChannels') > 0:\n",
    "        devices[i] = info\n",
    "\n",
    "if len(devices) == 0:\n",
    "    raise RuntimeError(\"Cannot find any valid input devices\")\n",
    "\n",
    "print(\"\\nInput Devices:\")\n",
    "for dev_id, info in devices.items():\n",
    "    print(\"{}: {}\".format(dev_id, info.get(\"name\")))\n",
    "input_device_id = int(input(\"Enter device id to use: \"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then record three seconds of audio from the selected device and play it back to verify that the microphone works."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device info:\n",
      "{   'defaultHighInputLatency': 0.048,\n",
      "    'defaultHighOutputLatency': -1.0,\n",
      "    'defaultLowInputLatency': 0.01196875,\n",
      "    'defaultLowOutputLatency': -1.0,\n",
      "    'defaultSampleRate': 32000.0,\n",
      "    'hostApi': 0,\n",
      "    'index': 3,\n",
      "    'maxInputChannels': 2,\n",
      "    'maxOutputChannels': 0,\n",
      "    'name': 'HD Pro Webcam C920: USB Audio (hw:1,0)',\n",
      "    'structVersion': 2}\n",
      "Recording\n",
      "Finished recording\n"
     ]
    }
   ],
   "source": [
    "import pprint\n",
    "pp = pprint.PrettyPrinter(indent=4)\n",
    "    \n",
    "print(\"Device info:\")\n",
    "devinfo = p.get_device_info_by_index(input_device_id)  # Or whatever device you care about.\n",
    "pp.pprint(devinfo)\n",
    "\n",
    "chunk = 1024  # Record in chunks of 1024 samples\n",
    "sample_format = pyaudio.paInt16  # 16 bits per sample\n",
    "channels = 1\n",
    "fs = int(devinfo['defaultSampleRate'])  # Record at the device's default sampling rate\n",
    "seconds = 3\n",
    "filename = \"test.wav\"\n",
    "\n",
    "print('Recording')\n",
    "\n",
    "stream = p.open(format=sample_format,\n",
    "                channels=channels,\n",
    "                rate=fs,\n",
    "                frames_per_buffer=chunk,\n",
    "                input=True,\n",
    "                input_device_index=input_device_id)\n",
    "\n",
    "frames = []  # Initialize array to store frames\n",
    "\n",
    "# Store data in chunks for 3 seconds\n",
    "for i in range(0, int(fs / chunk * seconds)):\n",
    "    data = stream.read(chunk)\n",
    "    frames.append(data)\n",
    "\n",
    "# Stop and close the stream \n",
    "stream.stop_stream()\n",
    "stream.close()\n",
    "# Terminate the PortAudio interface\n",
    "# p.terminate()\n",
    "\n",
    "print('Finished recording')\n",
    "\n",
    "# Save the recorded data as a WAV file\n",
    "wf = wave.open(filename, 'wb')\n",
    "wf.setnchannels(channels)\n",
    "wf.setsampwidth(p.get_sample_size(sample_format))\n",
    "wf.setframerate(fs)\n",
    "wf.writeframes(b''.join(frames))\n",
    "wf.close()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import IPython.display as ipd\n",
    "ipd.Audio(filename)"
   ]
  },
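  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optionally, the recording level can also be checked numerically. The sketch below is an addition to the original demo: it reads `test.wav` with the standard-library `wave` module and reports the peak level in dBFS (decibels relative to the 16-bit full scale of 32767). `read_int16_wav` and `peak_dbfs` are hypothetical helpers that assume a mono 16-bit PCM file, as recorded above. A level of `-inf` means the file contains only digital silence, which usually points to the wrong input device or a muted microphone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import wave\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def read_int16_wav(path):\n",
    "    # Read a mono 16-bit PCM WAV file into an int16 NumPy array\n",
    "    with wave.open(path, 'rb') as wf:\n",
    "        if wf.getsampwidth() != 2:\n",
    "            raise ValueError('expected 16-bit PCM samples')\n",
    "        frames = wf.readframes(wf.getnframes())\n",
    "    return np.frombuffer(frames, dtype=np.int16)\n",
    "\n",
    "\n",
    "def peak_dbfs(samples):\n",
    "    # An empty recording is treated like silence\n",
    "    if samples.size == 0:\n",
    "        return float('-inf')\n",
    "    # Widen to int32 so that abs(-32768) does not overflow int16\n",
    "    peak = np.max(np.abs(samples.astype(np.int32)))\n",
    "    if peak == 0:\n",
    "        return float('-inf')\n",
    "    return 20.0 * np.log10(peak / 32767.0)\n",
    "\n",
    "\n",
    "if os.path.exists('test.wav'):\n",
    "    print('Peak level: %.1f dBFS' % peak_dbfs(read_int16_wav('test.wav')))"
   ]
  },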
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "RL8d9IwzmTcV"
   },
   "source": [
    "<a id=\"3\"></a>\n",
    "## 3. Audio helper classes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "o6wayGf1mTcX"
   },
   "source": [
    "Next, we define a helper class for pre-processing audio. The `AudioSegment` class below takes an audio signal and resamples it to the rate required by the Kaldi ASR model, 16000 Hz by default.\n",
    "\n",
    "Note: For historical reasons, Kaldi expects waveforms in the range (2^15-1) x [-1, 1] rather than the usual DSP range [-1, 1]. Therefore, we scale the audio signal by a factor of (2^15-1)."
   ]
  },
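  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked example of this convention: for 16-bit input, the `_convert_samples_to_float32` method of `AudioSegment` first divides by 2^15-1 to reach [-1, 1] and then multiplies by `WAV_SCALE_FACTOR` = 2^15-1, so the net effect is a cast to float32 that keeps the original integer amplitudes, up to float32 rounding. The standalone sketch below reproduces only that arithmetic; `to_kaldi_scale` is a hypothetical name, and the cell does not depend on the `AudioSegment` class."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "KALDI_WAV_SCALE = 2**15 - 1  # same value as WAV_SCALE_FACTOR\n",
    "\n",
    "\n",
    "def to_kaldi_scale(int16_samples):\n",
    "    # Step 1: map int16 samples to the usual DSP range [-1, 1]\n",
    "    unit = int16_samples.astype(np.float32) * np.float32(1.0 / KALDI_WAV_SCALE)\n",
    "    # Step 2: rescale to the range Kaldi expects, (2^15-1) x [-1, 1]\n",
    "    return unit * np.float32(KALDI_WAV_SCALE)\n",
    "\n",
    "\n",
    "pcm = np.array([-32767, -1000, 0, 1000, 32767], dtype=np.int16)\n",
    "print(to_kaldi_scale(pcm))  # float32 values close to the original integers"
   ]
  },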
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "WAV_SCALE_FACTOR = 2**15-1\n",
    "\n",
    "class AudioSegment(object):\n",
    "    \"\"\"Monaural audio segment abstraction.\n",
    "    :param samples: Audio samples [num_samples x num_channels].\n",
    "    :type samples: ndarray.float32\n",
    "    :param sample_rate: Audio sample rate.\n",
    "    :type sample_rate: int\n",
    "    :raises TypeError: If the sample data type is not float or int.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, samples, sample_rate, target_sr=16000, trim=False,\n",
    "                 trim_db=60):\n",
    "        \"\"\"Create audio segment from samples.\n",
    "        Samples are converted to float32 internally and scaled to the Kaldi\n",
    "        range (2^15-1) x [-1, 1]; multi-channel input is down-mixed to mono.\n",
    "        \"\"\"\n",
    "        samples = self._convert_samples_to_float32(samples)\n",
    "        # Down-mix [num_samples x num_channels] input to mono before resampling\n",
    "        if samples.ndim >= 2:\n",
    "            samples = np.mean(samples, axis=1)\n",
    "        if target_sr is not None and target_sr != sample_rate:\n",
    "            samples = librosa.resample(samples, orig_sr=sample_rate, target_sr=target_sr)\n",
    "            sample_rate = target_sr\n",
    "        if trim:\n",
    "            samples, _ = librosa.effects.trim(samples, top_db=trim_db)\n",
    "        self._samples = samples\n",
    "        self._sample_rate = sample_rate\n",
    "\n",
    "    @staticmethod\n",
    "    def _convert_samples_to_float32(samples):\n",
    "        \"\"\"Convert samples to float32 in the range Kaldi expects.\n",
    "        Integer samples are first scaled to [-1, 1]; float samples are assumed\n",
    "        to already lie in [-1, 1]. Both are then multiplied by WAV_SCALE_FACTOR.\n",
    "        \"\"\"\n",
    "        float32_samples = samples.astype('float32')\n",
    "        if np.issubdtype(samples.dtype, np.signedinteger):\n",
    "            bits = np.iinfo(samples.dtype).bits\n",
    "            float32_samples *= (1. / ((2 ** (bits - 1)) - 1))\n",
    "        elif np.issubdtype(samples.dtype, np.floating):\n",
    "            pass\n",
    "        else:\n",
    "            raise TypeError(\"Unsupported sample type: %s.\" % samples.dtype)\n",
    "        return WAV_SCALE_FACTOR * float32_samples\n",
    "\n",
    "    @classmethod\n",
    "    def from_file(cls, filename, target_sr=16000, offset=0, duration=0,\n",
    "                 min_duration=0, trim=False):\n",
    "        \"\"\"\n",
    "        Load a file supported by soundfile and return it as an AudioSegment.\n",
    "        :param filename: path of file to load\n",
    "        :param target_sr: the desired sample rate\n",
    "        :param offset: offset in seconds when loading audio\n",
    "        :param duration: duration in seconds when loading audio (0 loads the whole file)\n",
    "        :param min_duration: zero-pad the audio to at least this many seconds\n",
    "        :param trim: trim leading and trailing silence\n",
    "        :return: AudioSegment\n",
    "        \"\"\"\n",
    "        with soundfile.SoundFile(filename, 'r') as f:\n",
    "            dtype_options = {'PCM_16': 'int16', 'PCM_32': 'int32', 'FLOAT': 'float32'}\n",
    "            dtype_file = f.subtype\n",
    "            if dtype_file in dtype_options:\n",
    "                dtype = dtype_options[dtype_file]\n",
    "            else:\n",
    "                dtype = 'float32'\n",
    "            sample_rate = f.samplerate\n",
    "            if offset > 0:\n",
    "                f.seek(int(offset * sample_rate))\n",
    "            if duration > 0:\n",
    "                samples = f.read(int(duration * sample_rate), dtype=dtype)\n",
    "            else:\n",
    "                samples = f.read(dtype=dtype)\n",
    "\n",
    "        # Zero-pad along the time axis (axis 0) at the file's own sample rate\n",
    "        num_zero_pad = int(sample_rate * min_duration - samples.shape[0])\n",
    "        if num_zero_pad > 0:\n",
    "            pad_width = [(0, num_zero_pad)] + [(0, 0)] * (samples.ndim - 1)\n",
    "            samples = np.pad(samples, pad_width, mode='constant')\n",
    "\n",
    "        # Keep the [num_samples x num_channels] layout expected by __init__\n",
    "        return cls(samples, sample_rate, target_sr=target_sr, trim=trim)\n",
    "\n",
    "    @property\n",
    "    def samples(self):\n",
    "        return self._samples.copy()\n",
    "\n",
    "    @property\n",
    "    def sample_rate(self):\n",
    "        return self._sample_rate"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<a id=\"4\"></a>\n",
    "## 4. Inference\n",
    "\n",
    "We first create an inference context object that connects to the Kaldi TRTIS server via a gRPC connection.\n",
    "\n",
    "The server expects audio in chunks of at most `input.WAV_DATA.dims` samples (8160 in the configuration above). At the default 16000 Hz sampling rate, this corresponds to 510 ms of audio per chunk. The last chunk of a sequence may contain fewer samples than this maximum."
   ]
  },
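  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The chunking rule can be illustrated without a server. The sketch below splits a 16 kHz signal into chunks of at most 8160 samples and flags the first chunk as the start and the last chunk as the end of the sequence, which is the role of `FLAG_SEQUENCE_START` and `FLAG_SEQUENCE_END` in the transcription code later in this notebook. `split_into_chunks` is a hypothetical helper that only computes chunk boundaries; it does not call the TRTIS client."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "MAX_CHUNK_SAMPLES = 8160  # input.WAV_DATA.dims in the model config above\n",
    "SAMPLE_RATE = 16000\n",
    "\n",
    "\n",
    "def split_into_chunks(samples, max_chunk=MAX_CHUNK_SAMPLES):\n",
    "    # Return (chunk, is_first, is_last) tuples; only the last chunk may be shorter\n",
    "    n_chunks = max(1, -(-len(samples) // max_chunk))  # ceiling division\n",
    "    chunks = []\n",
    "    for i in range(n_chunks):\n",
    "        chunk = samples[i * max_chunk:(i + 1) * max_chunk]\n",
    "        chunks.append((chunk, i == 0, i == n_chunks - 1))\n",
    "    return chunks\n",
    "\n",
    "\n",
    "print(MAX_CHUNK_SAMPLES / SAMPLE_RATE)  # seconds of audio in a full chunk\n",
    "one_second = np.zeros(SAMPLE_RATE, dtype=np.float32)\n",
    "for chunk, is_first, is_last in split_into_chunks(one_second):\n",
    "    print(len(chunk), is_first, is_last)"
   ]
  },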
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorrtserver.api import *\n",
    "protocol = ProtocolType.from_str(\"grpc\")\n",
    "\n",
    "CORRELATION_ID = 11101\n",
    "ctx = InferContext(FLAGS.url, protocol, FLAGS.model_name, FLAGS.model_version,\n",
    "                    correlation_id=CORRELATION_ID, verbose=True,\n",
    "                    streaming=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we read chunks of audio from the microphone (each 510 ms long, i.e. 8160 samples after resampling to 16000 Hz) and stream them sequentially to the Kaldi server, which processes each chunk as soon as it is received.\n",
    "\n",
    "Unlike a .wav file, live microphone input has no natural end, so no chunk is naturally marked with the sequence `END` flag. Instead, every 10th chunk is sent as the end of a sequence, so we receive a transcript once every 10 chunks. The server resets its state for the sequence once the result is sent, and the following chunk starts a new sequence."
   ]
  },
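  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a continuous stream, the intended bookkeeping reduces to a simple rule: with 10 chunks per sequence, chunk 0 of each sequence carries `START`, chunk 9 carries `END`, and the next chunk starts a new sequence, so one transcript is returned for roughly every 5.1 s of audio (10 x 510 ms). The sketch below expresses that rule as a function of the running chunk index. `sequence_flags` and `CHUNKS_PER_SEQUENCE` are hypothetical names for illustration only; the transcription class below tracks this state with its own `start`, `end` and `cnt` variables."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CHUNKS_PER_SEQUENCE = 10  # same value as MAX_CHUNKS in TranscribeFromMicrophone\n",
    "\n",
    "\n",
    "def sequence_flags(chunk_index, max_chunks=CHUNKS_PER_SEQUENCE):\n",
    "    # Position of this chunk inside its sequence: 0 .. max_chunks - 1\n",
    "    position = chunk_index % max_chunks\n",
    "    is_start = position == 0\n",
    "    is_end = position == max_chunks - 1\n",
    "    return is_start, is_end\n",
    "\n",
    "\n",
    "for i in range(12):\n",
    "    print(i, sequence_flags(i))"
   ]
  },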
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "class TranscribeFromMicrophone:\n",
    "\n",
    "    def __init__(self, input_device_id, target_sr, chunk_duration):\n",
    "\n",
    "        self.recording_state = \"init\"\n",
    "        self.target_sr = target_sr\n",
    "        self.chunk_duration = chunk_duration\n",
    "\n",
    "        self.p = pa.PyAudio()\n",
    "\n",
    "        device_info = self.p.get_host_api_info_by_index(0)\n",
    "        num_devices = device_info.get('deviceCount')\n",
    "        devices = {}\n",
    "        for i in range(0, num_devices):\n",
    "            if (self.p.get_device_info_by_host_api_device_index(0, i).get(\n",
    "                'maxInputChannels')) > 0:\n",
    "                devices[i] = self.p.get_device_info_by_host_api_device_index(\n",
    "                    0, i)\n",
    "\n",
    "        if (len(devices) == 0):\n",
    "            raise RuntimeError(\"Cannot find any valid input devices\")\n",
    "\n",
    "        if input_device_id is None or input_device_id not in \\\n",
    "            devices.keys():\n",
    "            print(\"\\nInput Devices:\")\n",
    "            for id, info in devices.items():\n",
    "                print(\"{}: {}\".format(id,info.get(\"name\")))\n",
    "            input_device_id = int(input(\"Enter device id to use: \"))\n",
    "\n",
    "        self.input_device_id = input_device_id\n",
    "        devinfo = self.p.get_device_info_by_index(input_device_id)\n",
    "        self.device_default_sr = int(devinfo['defaultSampleRate'])\n",
    "        print(\"Device sample rate: %d\" % self.device_default_sr)\n",
    "\n",
    "    def transcribe_audio(self, streaming=True):\n",
    "        ctx = InferContext(FLAGS.url, protocol, FLAGS.model_name, FLAGS.model_version,\n",
    "                    correlation_id=CORRELATION_ID, verbose=True,\n",
    "                    streaming=False)\n",
    "        \n",
    "        chunk_size = int(self.chunk_duration*self.device_default_sr)\n",
    "        self.recording_state = \"init\"\n",
    "\n",
    "        def keyboard_listener():\n",
    "            input(\"**********Press Enter to start and end transcribing...**********\")\n",
    "            self.recording_state = \"capture\"\n",
    "            print(\"Recording...\")\n",
    "            \n",
    "            input(\"\")\n",
    "            self.recording_state = \"release\"\n",
    "\n",
    "        listener = threading.Thread(target=keyboard_listener)\n",
    "        listener.start()\n",
    "\n",
    "        start = True\n",
    "        print(\"starting....\")\n",
    "        \n",
    "        stream_initialized = False\n",
    "        audio_signal = 0\n",
    "        audio_segment = 0\n",
    "        end = False\n",
    "        \n",
    "        cnt = 0\n",
    "        MAX_CHUNKS = 10\n",
    "        while self.recording_state != \"release\":\n",
    "            try:\n",
    "                if self.recording_state == \"capture\":\n",
    "\n",
    "                    if not stream_initialized:\n",
    "                        stream = self.p.open(\n",
    "                            format=pa.paInt16,\n",
    "                            channels=1,\n",
    "                            rate=self.device_default_sr,\n",
    "                            input=True,\n",
    "                            input_device_index=self.input_device_id,\n",
    "                            frames_per_buffer=chunk_size)\n",
    "                        stream_initialized = True\n",
    "\n",
    "                    # Read an audio chunk from microphone\n",
    "                    audio_signal = stream.read(chunk_size, exception_on_overflow = False)\n",
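    "                    if self.recording_state == \"release\":\n",
    "                        # Enter was pressed while reading: send this chunk as the\n",
    "                        # last one of the sequence so the server returns the transcript\n",
    "                        end = True\n",
    "                    audio_signal = np.frombuffer(audio_signal, dtype=np.int16)\n",
    "                    # Resample from the device rate to the model rate and apply Kaldi scaling\n",
    "                    audio_segment = AudioSegment(audio_signal,\n",
    "                                                 self.device_default_sr,\n",
    "                                                 self.target_sr)\n",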
    "                    \n",
    "                    if cnt == MAX_CHUNKS:\n",
    "                        end = True\n",
    "                    if cnt > 1:\n",
    "                        start = False\n",
    "                        \n",
    "                    # Inference: every chunk of a sequence shares CORRELATION_ID\n",
    "                    flags = InferRequestHeader.FLAG_NONE\n",
    "                    if start:\n",
    "                        flags |= InferRequestHeader.FLAG_SEQUENCE_START\n",
    "                    if end:\n",
    "                        flags |= InferRequestHeader.FLAG_SEQUENCE_END\n",
    "                    samples = audio_segment.samples\n",
    "                    inputs = {'WAV_DATA': (samples,),\n",
    "                              'WAV_DATA_DIM': (np.full(shape=1, fill_value=len(samples), dtype=np.int32),)}\n",
    "                    # Request the TEXT output only on the last chunk of a sequence\n",
    "                    outputs = {'TEXT': InferContext.ResultFormat.RAW} if end else {}\n",
    "                    res = ctx.run(inputs, outputs,\n",
    "                                  batch_size=1,\n",
    "                                  flags=flags,\n",
    "                                  corr_id=CORRELATION_ID)\n",
    "                    if end:\n",
    "                        print(''.join([t.decode('utf-8') for t in res['TEXT'][0]]))\n",
    "                    \n",
    "                    if cnt == MAX_CHUNKS: # reset server\n",
    "                        start = True\n",
    "                        end = False\n",
    "                        cnt = 0\n",
    "                    \n",
    "                    cnt += 1\n",
    "                    sys.stdout.write(\"\\r\" + \".\"*cnt)\n",
    "                    sys.stdout.flush()\n",
    "                    \n",
    "            except Exception as e:\n",
    "                print(e)\n",
    "                break\n",
    "\n",
    "        # The stream exists only if capture actually started\n",
    "        if stream_initialized:\n",
    "            stream.close()\n",
    "        self.p.terminate()"
   ]
  },
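  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `transcribe_audio` loop turns each raw PyAudio buffer into int16 samples with `np.frombuffer` and passes them to `AudioSegment` together with the device rate (`self.device_default_sr`) and the target rate (`self.target_sr`), which presumably resamples the chunk. To illustrate that conversion step in isolation, the cell below sketches it with plain linear interpolation (`np.interp`). The helper `pcm16_to_float` is hypothetical and is not the `AudioSegment` implementation; it also leaves the amplitude on the int16 scale, which is an assumption about what the server expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def pcm16_to_float(raw_bytes, source_sr, target_sr):\n",
    "    # Hypothetical helper: interpret raw PyAudio bytes as 16-bit signed PCM\n",
    "    samples = np.frombuffer(raw_bytes, dtype=np.int16).astype(np.float32)\n",
    "    # Amplitude stays on the int16 scale (assumption; no normalization)\n",
    "    if source_sr == target_sr:\n",
    "        return samples\n",
    "    # Linear-interpolation resampling, a simple stand-in for a proper resampler\n",
    "    n_out = int(round(len(samples) * target_sr / source_sr))\n",
    "    t_in = np.arange(len(samples)) / source_sr\n",
    "    t_out = np.arange(n_out) / target_sr\n",
    "    return np.interp(t_out, t_in, samples).astype(np.float32)\n",
    "\n",
    "# Example: 0.1 s of silence captured at 32 kHz, resampled to 16 kHz\n",
    "raw = np.zeros(3200, dtype=np.int16).tobytes()\n",
    "out = pcm16_to_float(raw, 32000, 16000)\n",
    "print(out.shape, out.dtype)  # (1600,) float32"
   ]
  },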
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Device sample rate: 32000\n"
     ]
    }
   ],
   "source": [
    "transcriber = TranscribeFromMicrophone(input_device_id,\n",
    "    target_sr=FLAGS.sample_rate,\n",
    "    chunk_duration=FLAGS.chunk_duration)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After you run the cell below, press ENTER to start recording: the client captures audio chunks from the selected microphone and streams them continuously to the server. After every 10 chunks (`MAX_CHUNKS`), the client retrieves and displays the transcription, and the server-side sequence is reset, i.e., the next chunk is treated as the start of a new request.\n",
    "Press ENTER again to stop the client.\n",
    "\n"
   ]
  },
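  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The loop in `transcribe_audio` groups consecutive chunks into sequences that share `CORRELATION_ID`: the first chunk of each sequence carries `FLAG_SEQUENCE_START`, the chunk that completes it carries `FLAG_SEQUENCE_END`, and only that final request asks for the `TEXT` output. The cell below sketches this intended schedule in plain Python so it runs without a server. `sequence_flags` is a hypothetical helper, and its `start_bit`/`end_bit` values are placeholders rather than the real `InferRequestHeader` constants."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sequence_flags(num_chunks, max_chunks=10, start_bit=1, end_bit=2):\n",
    "    # start_bit/end_bit are illustrative stand-ins for FLAG_SEQUENCE_START and\n",
    "    # FLAG_SEQUENCE_END (their real values in InferRequestHeader are not assumed)\n",
    "    flags = []\n",
    "    for i in range(num_chunks):\n",
    "        pos = i % max_chunks      # position of this chunk inside its sequence\n",
    "        f = 0                     # stand-in for FLAG_NONE\n",
    "        if pos == 0:\n",
    "            f |= start_bit        # first chunk opens a new sequence\n",
    "        if pos == max_chunks - 1:\n",
    "            f |= end_bit          # last chunk closes it; TEXT is requested here\n",
    "        flags.append(f)\n",
    "    return flags\n",
    "\n",
    "flags = sequence_flags(25, max_chunks=10)\n",
    "print(flags[:10])                                 # [1, 0, 0, 0, 0, 0, 0, 0, 0, 2]\n",
    "print([i for i, f in enumerate(flags) if f & 2])  # [9, 19]"
   ]
  },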
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "transcriber.transcribe_audio()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "g8MxXY5GmTc8"
   },
   "source": [
    "# Conclusion\n",
    "\n",
    "In this notebook, we walked through the complete process of capturing audio from a microphone, preparing it, and running streaming inference with the Kaldi ASR model on TRTIS.\n",
    "\n",
    "## What's next\n",
    "Now it's time to try the Kaldi ASR model on your own data. The online client can also be improved further, for example by detecting natural breaks in the input stream (e.g., silence) to segment sentences more accurately.\n"
   ]
  },
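  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a starting point for the silence-based segmentation suggested above, the cell below flags a chunk as silent when its root-mean-square (RMS) energy falls below a threshold. `is_silent` and the default threshold of 500 (on the int16 amplitude scale) are assumptions chosen for illustration; a more robust client would likely use a dedicated voice-activity detector and tune the threshold for each microphone."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def is_silent(chunk, threshold=500.0):\n",
    "    # chunk: 1-D array of int16 samples, as returned by np.frombuffer in the client\n",
    "    samples = np.asarray(chunk, dtype=np.float64)\n",
    "    if samples.size == 0:\n",
    "        return True\n",
    "    # Root-mean-square energy of the chunk, on the int16 amplitude scale\n",
    "    rms = np.sqrt(np.mean(samples ** 2))\n",
    "    return bool(rms < threshold)\n",
    "\n",
    "quiet = np.zeros(1600, dtype=np.int16)\n",
    "loud = (3000 * np.sin(np.linspace(0, 100, 1600))).astype(np.int16)\n",
    "print(is_silent(quiet), is_silent(loud))  # True False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One hypothetical way to use it would be to set `end = True` in `transcribe_audio` once several consecutive chunks are silent, instead of always waiting for `MAX_CHUNKS` chunks."
   ]
  },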
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "249yGNLmmTc_"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "include_colab_link": true,
   "name": "TensorFlow_UNet_Industrial_Colab_train_and_inference.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
